import math
import logging
from functools import partial
import scipy.special
import scipy.stats
import numpy as np
from models import (
    Metric, QuerySet, ELO_INITIAL_SCORE, ELO_SCALE_FACTOR,
    ELO_MAX_UPDATES, ELO_CONVERGENCE_THRESHOLD,
    GAMMA_PRIOR_SHAPE, GAMMA_PRIOR_RATE,
    skill_to_elo, elo_to_skill, BETA_PRIOR_ALPHA, BETA_PRIOR_BETA
)
from copy import deepcopy
import time

logger = logging.getLogger(__name__)

# (state key, probability level) pairs reported for every method's score.
_SCORE_QUANTILES = (
    ('p005', 0.005),
    ('p025', 0.025),
    ('p05', 0.05),
    ('median', 0.5),
    ('p95', 0.95),
    ('p975', 0.975),
    ('p995', 0.995),
)


def _rating_responsibilities(skills, rater_qualities):
    """Return gamma[i, j, r], the weight given to a win of i over j by rater r.

    With q_r the quality of rater r and p_ij = s_i / (s_i + s_j), the weight is
    q_r * p_ij / (q_r * p_ij + (1 - q_r) / 2). Entries with i == j are set to 0.
    """
    win_prob = skills[:, np.newaxis] / (skills[:, np.newaxis] + skills[np.newaxis, :])
    q = rater_qualities[np.newaxis, np.newaxis, :]
    informed = q * win_prob[..., np.newaxis]
    gamma = informed / (informed + (1 - q) / 2)
    diag = np.arange(len(skills))
    gamma[diag, diag, :] = 0
    return gamma


def _skill_gamma_params(wins, gamma, skills, a, b):
    """Return per-method (shape, rate) arrays; shape / rate is the updated skill."""
    weighted = wins * gamma
    shape = a - 1 + weighted.sum(axis=(1, 2))
    pair_skill = skills[:, np.newaxis, np.newaxis] + skills[np.newaxis, :, np.newaxis]
    # Comparisons count for both participants, hence the symmetrised numerator.
    rate = b + np.sum(
        (weighted + np.transpose(weighted, (1, 0, 2))) / pair_skill,
        axis=(1, 2),
    )
    return shape, rate


def _rater_quality_estimates(wins, gamma, alpha, beta):
    """Return the updated quality of every rater from gamma-weighted win counts."""
    upper = np.triu_indices(wins.shape[0], k=1)
    lower = (upper[1], upper[0])
    weighted = np.sum(wins[upper] * gamma[upper] + wins[lower] * gamma[lower], axis=0)
    total = np.sum(wins[upper] + wins[lower], axis=0)
    return (alpha - 1 + weighted) / (alpha + beta - 2 + total)


def update_bayesian_elo(metric, sessions):
    """Implements an iterative procedure for computing Elo scores.

    This algorithm is based on Caron & Doucet's (2010) Bayesian interpretation of an algorithm by
    Hunter (2004). However, where Caron & Doucet view it as an expectation maximization algorithm
    for computing a MAP estimate, we here go a step further and interpret it as mean-field
    variational inference steps. This allows us to additionally compute uncertainty estimates.

    Each rater additionally has a quality value in [0, 1] that weights their ratings (see
    `_rating_responsibilities`); qualities are re-estimated together with the scores.

    The following keys of `metric.state` are read: 'scores', 'qualities', 'methods', 'raters'
    and 'wins'. The following keys are written: 'methods', 'raters', 'wins', 'rater_qualities'
    and 'scores'. Win counts stored in the state are reused when their shape matches the stored
    methods and raters, and the given sessions are added on top of them, so `sessions` should
    only contain sessions that have not been counted yet.

    Args:
        metric: A Metric object containing the state
        sessions: A QuerySet object containing the sessions to process

    Returns:
        Seconds spent collecting win statistics and iterating the updates. The final
        uncertainty estimation is not included in this time.

    References:
        Caron & Doucet (2010), Efficient Bayesian Inference for the Bradley-Terry Model
        https://www.stats.ox.ac.uk/~doucet/caron_doucet_bayesianbradleyterry.pdf
    """
    start = time.perf_counter()

    # Parameters of a Gamma prior over skill values. The parameter "b" only determines the scale
    # of skill values. Its choice may matter numerically but should otherwise have no influence
    # on resulting Elo scores.
    a, b = GAMMA_PRIOR_SHAPE, GAMMA_PRIOR_RATE
    # Parameters of the Beta prior over rater qualities.
    alpha, beta = BETA_PRIOR_ALPHA, BETA_PRIOR_BETA

    # Elo scores and rater qualities, each stored as {name: {'value': ...}}.
    # NOTE(review): qualities are read from 'qualities' but written back as 'rater_qualities'
    # (flat floats), so stored qualities are not picked up on the next call -- confirm whether
    # this is intended.
    scores = metric.state.get('scores', {})
    qualities = metric.state.get('qualities', {})

    # Position in these lists is the index into the corresponding axis of the win tensor.
    methods = metric.state.get('methods', [])
    method_to_index = {name: k for k, name in enumerate(methods)}
    raters = metric.state.get('raters', [])
    rater_to_index = {name: k for k, name in enumerate(raters)}

    # Win tensor: wins[i, j, r] counts how often rater r preferred method i over method j.
    # It must always agree in shape with the known methods and raters, otherwise indexing an
    # already known method fails.
    expected_shape = (len(methods), len(methods), len(raters))
    stored_wins = np.asarray(metric.state.get('wins', []), dtype=float)
    if stored_wins.shape == expected_shape:
        wins = stored_wins.copy()
    else:
        wins = np.zeros(expected_shape)

    # Makes it so that the mode of the gamma distribution corresponds to an Elo score of
    # ELO_INITIAL_SCORE.
    elo_offset = ELO_INITIAL_SCORE - math.log10((a - 1) / b) * ELO_SCALE_FACTOR
    s2e = partial(skill_to_elo, offset=elo_offset)
    e2s = partial(elo_to_skill, offset=elo_offset)

    # Collect and update win statistics.
    for session in sessions:
        for slate in session.slates.all():
            ratings = slate.ratings.all()
            count = len(ratings)

            if count != 2:
                logger.error(
                    'Slate %s in session %s has %s ratings instead of 2.',
                    slate.id, session.id, count
                )
                continue

            rater = session.rater
            first, second = ratings[0], ratings[1]

            # Register a new rater by growing the third axis.
            if rater not in rater_to_index:
                rater_to_index[rater] = len(raters)
                raters.append(rater)
                wins = np.pad(wins, ((0, 0), (0, 0), (0, 1)))
            index_rater = rater_to_index[rater]

            # Register new methods by growing both method axes together.
            pair_indices = []
            for name in (first.stimulus.name, second.stimulus.name):
                if name not in method_to_index:
                    method_to_index[name] = len(methods)
                    methods.append(name)
                    wins = np.pad(wins, ((0, 1), (0, 1), (0, 0)))
                pair_indices.append(method_to_index[name])
            index_i, index_j = pair_indices

            if first.score > second.score:
                wins[index_i, index_j, index_rater] += 1
            elif second.score > first.score:
                wins[index_j, index_i, index_rater] += 1
            else:
                # We don't model ties explicitly. Instead, we consider the expected outcome of a
                # forced choice.
                wins[index_i, index_j, index_rater] += 0.5
                wins[index_j, index_i, index_rater] += 0.5

    metric.state['methods'] = methods
    metric.state['raters'] = raters
    metric.state['wins'] = wins.tolist()

    n = len(methods)

    def current_estimates():
        """Return current (elo, skill, quality) arrays, using defaults for unseen entries."""
        elo = np.array([scores.get(m, {}).get('value', ELO_INITIAL_SCORE) for m in methods])
        quality = np.array([qualities.get(r, {}).get('value', 0.5) for r in raters])
        return elo, e2s(elo), quality

    # Update scores. Skipped without methods, where the reductions below would be empty.
    if n:
        for _ in range(ELO_MAX_UPDATES * sessions.count()):
            elo_before, skills, rater_quality = current_estimates()
            gamma = _rating_responsibilities(skills, rater_quality)

            # Update all scores at once.
            shape, rate = _skill_gamma_params(wins, gamma, skills, a, b)
            elo_after = s2e(shape / rate)
            for method, elo in zip(methods, elo_after):
                scores[method] = {'value': elo}

            # Update all qualities at once.
            quality_after = _rater_quality_estimates(wins, gamma, alpha, beta)
            for rater, quality in zip(raters, quality_after):
                qualities[rater] = {'value': quality}

            # Check for convergence.
            if np.max(np.abs(elo_after - elo_before)) < ELO_CONVERGENCE_THRESHOLD:
                break

    delta_time = time.perf_counter() - start

    metric.state['rater_qualities'] = {
        r: float(qualities.get(r, {}).get('value', 0.5)) for r in raters
    }

    # Estimate uncertainty. The Gamma parameters depend only on the final scores and qualities,
    # so they are computed once for all methods.
    if n:
        final_elo, final_skills, final_quality = current_estimates()
        gamma = _rating_responsibilities(final_skills, final_quality)
        shape, rate = _skill_gamma_params(wins, gamma, final_skills, a, b)
        levels = [level for _, level in _SCORE_QUANTILES]

        for i, method in enumerate(methods):
            # Gamma distribution approximating the posterior distribution over the skill.
            # The scale is the reciprocal of the rate.
            skill_quantiles = scipy.stats.gamma(a=shape[i], scale=1 / rate[i]).ppf(levels)

            # Convert percentiles to Elo scale, pairing each key with its own level.
            entry = scores.setdefault(method, {'value': final_elo[i]})
            for (key, _), skill_quantile in zip(_SCORE_QUANTILES, skill_quantiles):
                entry[key] = s2e(skill_quantile)

    metric.state['scores'] = scores
    return delta_time